import matplotlib.pyplot as plt
from matplotlib.pyplot import imread, imshow
import math
import numpy as np
import os

def afficher_matrice_confusion(matrice_confusion, normaliser=False, afficher_valeurs=False, titre="Matrice de confusion",
                                taille_figure=(16, 14), cmap='Blues'):

   # Conversion en numpy array si nécessaire
    cm = np.array(matrice_confusion)

    # Normalisation si demandée
    if normaliser:
        cm = cm.astype('float') / cm.sum(axis=1, keepdims=True)
        fmt = '.2f'
    else:
        fmt = 'd'

    # Création de la figure
    fig, ax = plt.subplots(figsize=taille_figure)

    # Affichage de la matrice
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)

    # Configuration des axes
    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           xticklabels=list(symboles_confusion[0]),
           yticklabels=list(symboles_confusion[0]),
           title=titre,
           ylabel='Caractère réel',
           xlabel='Caractère prédit')

    # Rotation des labels pour meilleure lisibilité
    plt.setp(ax.get_xticklabels(), rotation=90, ha="right", rotation_mode="anchor")

    # Ajustement de la taille de police pour 79 caractères
    ax.tick_params(axis='both', which='major', labelsize=7)

    # Affichage optionnel des valeurs (déconseillé pour 79x79)
    if afficher_valeurs:
        thresh = cm.max() / 2.
        for i in range(cm.shape[0]):
            for j in range(cm.shape[1]):
                ax.text(j, i, format(cm[i, j], fmt),
                       ha="center", va="center",
                       color="white" if cm[i, j] > thresh else "black",
                       fontsize=4)

    # Ajustement de la mise en page avec plus de marge
    fig.tight_layout(pad=2.0)

    # Ajustement manuel des marges pour éviter la coupure
    plt.subplots_adjust(left=0.08, right=0.95, top=0.96, bottom=0.08)
    plt.show()



# Listes de symboles
categories = ["majuscules", "minuscules", "chiffres", "special"]
symboles = [
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
    "abcdefghijklmnopqrstuvwxyz",
    "0123456789",
    ".:,;'(!?)éèàçùêûâ",
]
symboles_confusion = ["ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789.:,;'(!?)éèàçùêûâ"]


# Images à lire : les images sont lues dans le répertoire <base_11x79>
# On lit 3 images du caractère "K"
#        2 images du caractère "P"
#        1 image du caractère "?"
liste_images_à_lire = ["Open Sans Light_majuscules18_10.png",
                       "Segoe UI Semilight_majuscules18_10.png",
                       "Pavanam_majuscules18_10.png",
                       "Sarabun ExtraLight_majuscules18_15.png",
                       "Microsoft Yi Baiti_majuscules18_15.png",
                       "Segoe UI Semilight_special18_7.png"]


def lire_symbole_fichier(nomFichier: str) -> str:
    car = nomFichier.split('_')
    num = car[2].split('.')[0]
    var = car[1][:len(car[1])-2]
    ind = categories.index(var)
    return symboles[ind][int(num)]

def liste_fichiers_repertoire(repertoire):
    liste_fichiers = []
    for nom_fichier in os.listdir(repertoire):
        liste_fichiers.append(nom_fichier)

    return liste_fichiers


# Question 2
def lire_donnees_ref(dossier, fichiers_car_ref):
    # Initialise le dictionnaire vide
    caract_ref = {}

    # Itération sur les fichiers
    for fichier in fichiers_car_ref:
        # Récupère le symbole
        symbole_fichier = lire_symbole_fichier(fichier)

        # Récupère l'array de l'image
        img = imread(dossier + "\\" + fichier)

        # Enregistre l'information dans le dictionnaire
        if symbole_fichier in caract_ref:
            caract_ref[symbole_fichier].append(img)
        else:
            caract_ref[symbole_fichier] = [img]
    return caract_ref

# Question 3 : Sans numpy
def distance(im1, im2):
    n_l, n_c = im1.shape          # Nombre de lignes / colonnes

    # Initialise la distance = 0
    d = 0

    # Calcul de la distance
    for i in range(n_l):
        for j in range(n_c):
            d = d + (im1[i][j] - im2[i][j])**2
    return math.sqrt(d)

# Question 4 : Avec numpy
def distance_np(im1, im2):
    return np.linalg.norm(im1-im2)

# Question 5 :
def calcul_distances(carac_ref, caract_test):
    # Initialise le dictionnaire des distances
    distances = {key:[] for key in carac_ref.keys()}

    # Itère sur les différents caractères stockés
    # dans le dictionnaire des tableaux catégorisé "carac_ref"
    for caractere in carac_ref:
        for arr in carac_ref[caractere]:
            dist = distance_np(arr, caract_test)
            distances[caractere].append(dist)
    return distances

# Question 7 : algorithme Kvoisins à compléter
def Kvoisins(distances:dict, K:int ) -> list :
    voisins = [(float("inf"),"") for k in range (K)]
    for lettre in distances:
        d = distances[lettre]
        for j in range (len(d)):
            if d[j] < voisins[-1][0]:
                k = len(voisins)-1
                while k > 0 and d[j] < voisins[k-1][0]:
                    voisins[k] = voisins[k-1]
                    k=k-1
                voisins[k] = [d[j], lettre]
    return voisins

# Question 9 : algorithme du vote majoritaire
def symbole_majoritaire(voisins):
    majoritaires = {}
    for dist,symbole in voisins:
        if symbole in majoritaires:
            majoritaires[symbole] = majoritaires[symbole] + 1
        else:
            majoritaires[symbole] = 1
    max_ = max(majoritaires.values())
    index_max = list(majoritaires.values()).index(max_)

    return list(majoritaires.keys())[index_max]

# Question 11: algorithme pour lire les symbole "test_mot"
def Lire_test_mot(repertoire):
    symboles_numpy = {}
    i = 0

    for nom_fichier in os.listdir(repertoire):
        # Lecture de l'image en array
        img = imread(repertoire + "\\" + nom_fichier)

        # Récupère la catégorie et l'index du symmbole
        cat = nom_fichier.split("_")[1]
        indx_symbole = nom_fichier.split("_")[2].split(".")[0]

        # Sauvegarde dans le dictionnaire
        car = symboles[categories.index(cat)][int(indx_symbole)]
        symboles_numpy[i] = (img, car)
        i = i + 1

    return symboles_numpy

# Question 12 : KNN pour retrouver le mot
def KNN_test(symboles_numpy, k, base):
    liste_fichiers = liste_fichiers_repertoire(base)
    caract_ref = lire_donnees_ref(base,liste_fichiers)

    mot = []

    for symbole in symboles_numpy:
        distances = calcul_distances(caract_ref, symboles_numpy[symbole][0])
        voisins = Kvoisins(distances,k)
        mot.append(symbole_majoritaire(voisins))

    return mot


# Question 14: Matrice de confusion
def Matrice_confusion(k,base):
    # Création de la matrice
    matrice = np.zeros(shape=(79,79), dtype=np.int16)


    # Base d'apprentissage
    liste_fichiers = liste_fichiers_repertoire(base)
    caract_ref = lire_donnees_ref(base,liste_fichiers)

    # Construction du dictionnaire symboles_numpy
    symboles_numpy = Lire_test_mot("test_confusion")

    # Recherche des n premiers symboles
    for i,symbole in enumerate(symboles_confusion[0]):
        # Récupère l'index du symbole numpy correspondant
        # au symbole issu de la liste de confusion
        liste_symboles = [symboles_numpy[i][1] for i in range(79)]
        index_symbole = liste_symboles.index(symbole)

        # Applique KNN
        distances = calcul_distances(caract_ref, symboles_numpy[index_symbole][0])
        voisins = Kvoisins(distances,k)
        resultat = symbole_majoritaire(voisins)

        # Complète la matrice
        colonne = symboles_confusion[0].index(resultat)
        matrice[i][colonne] = matrice [i][colonne] + 1

    return matrice

# Question 16 : Taux de réussite
def Taux_de_reussite(matrice):
    # Calcul de la trace
    trace = 0
    for i in range(matrice.shape[0]):
        trace = trace + matrice[i,i]

    # Calcule de la norme
    norme = 0
    for i in range(matrice.shape[0]):
        for j in range(matrice.shape[0]):
            norme = norme + matrice[i,j]

    return float(trace/norme)



# Test Question 2
print("\nTest lecture des fichiers images : on lit 3 images différentes du caractère 'K'")
caract_ref = lire_donnees_ref("base_11x79",liste_images_à_lire)
print(caract_ref.keys())
print(len(caract_ref['K']))
imshow(caract_ref['K'][0], cmap="gray", vmin=0, vmax=1)
#plt.show()


# Tests Question 3 et 4
print("\nTests des calculs de distances")
caract_test = lire_donnees_ref("base_référence",["Zurich Light BT_majuscules18_10.png"])

print(distance(caract_ref['K'][0],caract_ref['K'][0]))
print(distance(caract_test['K'][0],caract_ref['K'][0]))
print(distance(caract_test['K'][0],caract_ref['K'][1]))
print(distance(caract_test['K'][0],caract_ref['K'][2]))

print(distance_np(caract_ref['K'][0],caract_ref['K'][0]))
print(distance_np(caract_test['K'][0],caract_ref['K'][0]))
print(distance_np(caract_test['K'][0],caract_ref['K'][1]))
print(distance_np(caract_test['K'][0],caract_ref['K'][2]))

# Tests Question 5
print("\nTest question 5")
distances = calcul_distances(caract_ref,caract_test['K'][0])
print(distances)

# Test Question 7
print("\nTest question 7")
print(Kvoisins(distances,6))

# Test Question 9
print("\nTest question 9")
voisins = Kvoisins(distances,5)
symbole = symbole_majoritaire(voisins)
print(symbole)

# Test Question 11
symboles_numpy = Lire_test_mot("test_mot")

# Tests Question 12
print("\nTest question 12")
print(KNN_test(symboles_numpy,1,"base_référence"))
print(KNN_test(symboles_numpy,4,"base_11x79"))

# Test question 14
print("\nTest question 14")
matrice = Matrice_confusion(1,"base_référence")
afficher_matrice_confusion(matrice)
matrice = Matrice_confusion(4,"base_11x79")
afficher_matrice_confusion(matrice)

# Test question 16
print("\nTest question 16")
acc = Taux_de_reussite(matrice)
print(acc)

